//
// Copyright 2023 The GUAC Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build integration

package backend_test

import (
	"context"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/guacsec/guac/internal/testing/ptrfrom"
	"github.com/guacsec/guac/internal/testing/testdata"
	"github.com/guacsec/guac/pkg/assembler/graphql/model"
)

// TestVulnerability ingests vulnerabilities one at a time through
// IngestVulnerability and checks what VulnerabilityList returns for a range of
// query specs.
//
// Every subtest uses the single backend created by setupTest at the top of the
// function, so whatever state that backend keeps between calls is shared across
// subtests. Each case should therefore ingest everything its expectation lists.
func TestVulnerability(t *testing.T) {
	ctx := context.Background()
	b := setupTest(t)
	tests := []struct {
		// Name is the subtest name.
		Name string
		// Ingests are ingested in order before the query runs.
		Ingests []*model.VulnerabilityInputSpec
		// ExpIngestErr states whether ingestion is expected to fail.
		ExpIngestErr bool
		// Query is the spec passed to VulnerabilityList. It is replaced when
		// QueryID is set.
		Query *model.VulnerabilitySpec
		// QueryID makes the test query by the node ID returned from the last
		// successful ingest instead of using Query.
		QueryID bool
		// Exp is the expected result, grouped by type as convertToVulnTrie does.
		Exp []*model.Vulnerability
		// ExpQueryErr states whether the query is expected to fail.
		ExpQueryErr bool
	}{
		{
			Name:    "HappyPath",
			Ingests: []*model.VulnerabilityInputSpec{testdata.C1},
			Query:   &model.VulnerabilitySpec{},
			Exp: []*model.Vulnerability{
				{
					Type:             "cve",
					VulnerabilityIDs: []*model.VulnerabilityID{testdata.C1out},
				},
			},
		},
		{
			Name:    "Multiple",
			Ingests: []*model.VulnerabilityInputSpec{testdata.C1, testdata.C2},
			Query:   &model.VulnerabilitySpec{},
			Exp: []*model.Vulnerability{
				{
					Type:             "cve",
					VulnerabilityIDs: []*model.VulnerabilityID{testdata.C2out, testdata.C1out},
				},
			},
		},
		{
			Name:    "Duplicates",
			Ingests: []*model.VulnerabilityInputSpec{testdata.C1, testdata.C1, testdata.C1},
			Query: &model.VulnerabilitySpec{
				VulnerabilityID: ptrfrom.String("cve-2019-13110"),
			},
			Exp: []*model.Vulnerability{
				{
					Type:             "cve",
					VulnerabilityIDs: []*model.VulnerabilityID{testdata.C1out},
				},
			},
		},
		{
			Name:    "Query by type - cve",
			Ingests: []*model.VulnerabilityInputSpec{testdata.C1, testdata.C2, testdata.C3, testdata.G1, testdata.O1},
			Query: &model.VulnerabilitySpec{
				Type: ptrfrom.String("cve"),
			},
			Exp: []*model.Vulnerability{
				{
					Type:             "cve",
					VulnerabilityIDs: []*model.VulnerabilityID{testdata.C3out, testdata.C2out, testdata.C1out},
				},
			},
		},
		// G1 is ingested explicitly here because the expectation lists G1out;
		// the case must not depend on an earlier subtest having ingested it.
		{
			Name:    "Query by type - ghsa",
			Ingests: []*model.VulnerabilityInputSpec{testdata.C1, testdata.C2, testdata.C3, testdata.G1, testdata.G2, testdata.G3, testdata.O1},
			Query: &model.VulnerabilitySpec{
				Type: ptrfrom.String("ghsa"),
			},
			Exp: []*model.Vulnerability{
				{
					Type:             "ghsa",
					VulnerabilityIDs: []*model.VulnerabilityID{testdata.G3out, testdata.G2out, testdata.G1out},
				},
			},
		},
		{
			Name:    "Query by type - osv",
			Ingests: []*model.VulnerabilityInputSpec{testdata.C1, testdata.C2, testdata.C3, testdata.G3, testdata.O1, testdata.O2, testdata.O3},
			Query: &model.VulnerabilitySpec{
				Type: ptrfrom.String("osv"),
			},
			Exp: []*model.Vulnerability{
				{
					Type:             "osv",
					VulnerabilityIDs: []*model.VulnerabilityID{testdata.O3out, testdata.O2out, testdata.O1out},
				},
			},
		},
		{
			Name:    "Query by type - noVuln",
			Ingests: []*model.VulnerabilityInputSpec{testdata.NoVulnInput},
			Query: &model.VulnerabilitySpec{
				Type: ptrfrom.String("noVuln"),
			},
			Exp: []*model.Vulnerability{
				{
					Type:             "novuln",
					VulnerabilityIDs: []*model.VulnerabilityID{testdata.NoVulnOut},
				},
			},
		},
		{
			Name:    "Query by type - noVuln with boolean",
			Ingests: []*model.VulnerabilityInputSpec{testdata.NoVulnInput},
			Query: &model.VulnerabilitySpec{
				NoVuln: ptrfrom.Bool(true),
			},
			Exp: []*model.Vulnerability{
				{
					Type:             "novuln",
					VulnerabilityIDs: []*model.VulnerabilityID{testdata.NoVulnOut},
				},
			},
		},
		{
			Name:    "Query by vulnID",
			Ingests: []*model.VulnerabilityInputSpec{testdata.C1, testdata.C2, testdata.C3},
			Query: &model.VulnerabilitySpec{
				Type:            ptrfrom.String("cve"),
				VulnerabilityID: ptrfrom.String("CVE-2014-8140"),
			},
			Exp: []*model.Vulnerability{
				{
					Type:             "cve",
					VulnerabilityIDs: []*model.VulnerabilityID{testdata.C3out},
				},
			},
		},
		{
			Name:    "Query by vulnID - noVuln",
			Ingests: []*model.VulnerabilityInputSpec{testdata.C1, testdata.C2, testdata.C3, testdata.NoVulnInput},
			Query: &model.VulnerabilitySpec{
				Type: ptrfrom.String("noVuln"),
			},
			Exp: []*model.Vulnerability{
				{
					Type:             "novuln",
					VulnerabilityIDs: []*model.VulnerabilityID{testdata.NoVulnOut},
				},
			},
		},
		{
			Name:    "Query on ID",
			Ingests: []*model.VulnerabilityInputSpec{testdata.C1},
			QueryID: true,
			Exp: []*model.Vulnerability{
				{
					Type:             "cve",
					VulnerabilityIDs: []*model.VulnerabilityID{testdata.C1out},
				},
			},
		},
		{
			Name:    "Query none",
			Ingests: []*model.VulnerabilityInputSpec{testdata.C1, testdata.C2, testdata.C3},
			Query: &model.VulnerabilitySpec{
				Type: ptrfrom.String("5258"),
			},
			Exp: nil,
		},
		{
			Name:    "Query none ID",
			Ingests: []*model.VulnerabilityInputSpec{testdata.C1, testdata.C2, testdata.C3},
			Query: &model.VulnerabilitySpec{
				ID: ptrfrom.String("bbcc0454-d1ca-484c-b26f-e7b6576ef04e"),
			},
			Exp: nil,
		},
	}
	for _, test := range tests {
		t.Run(test.Name, func(t *testing.T) {
			for _, i := range test.Ingests {
				vulnIDs, err := b.IngestVulnerability(ctx, model.IDorVulnerabilityInput{VulnerabilityInput: i})
				if (err != nil) != test.ExpIngestErr {
					t.Fatalf("did not get expected ingest error, want: %v, got: %v", test.ExpIngestErr, err)
				}
				if err != nil {
					// The ingest failed as expected; there is nothing to query.
					return
				}
				if test.QueryID {
					// Query by the node ID of the most recently ingested
					// vulnerability.
					test.Query = &model.VulnerabilitySpec{
						ID: ptrfrom.String(vulnIDs.VulnerabilityNodeID),
					}
				}
			}
			// A case with neither Query nor an ingest-derived ID would panic on
			// the dereference below; report it as a broken test case instead.
			if test.Query == nil {
				t.Fatalf("test case %q has no query to run", test.Name)
			}
			got, err := b.VulnerabilityList(ctx, *test.Query, nil, nil)
			if (err != nil) != test.ExpQueryErr {
				t.Fatalf("did not get expected query error, want: %v, got: %v", test.ExpQueryErr, err)
			}
			if err != nil {
				return
			}
			// Flatten the returned edges into the list of nodes they carry.
			var returnedObjects []*model.Vulnerability
			if got != nil {
				for _, obj := range got.Edges {
					returnedObjects = append(returnedObjects, obj.Node)
				}
			}
			if diff := cmp.Diff(test.Exp, convertToVulnTrie(returnedObjects), commonOpts); diff != "" {
				t.Errorf("Unexpected results. (-want +got):\n%s", diff)
			}
		})
	}
}

// convertToVulnTrie groups a flat list of vulnerability nodes by type.
//
// The result holds one entry per distinct Type, in the order the types first
// appear in vulnObjs. Each entry's ID is "vulnerability_types:<type>" and its
// VulnerabilityIDs are copies (ID and VulnerabilityID fields) of the IDs carried
// by every input node of that type, in input order.
//
// Nil nodes and nil ID entries are skipped. A node without any IDs still
// contributes its type, so it shows up in a comparison rather than panicking.
// The result is nil when there is nothing to group.
func convertToVulnTrie(vulnObjs []*model.Vulnerability) []*model.Vulnerability {
	// byType finds the group already created for a type. The groups are kept in
	// a slice as well so the output order does not depend on map iteration.
	byType := map[string]*model.Vulnerability{}

	// Deliberately left nil rather than pre-sized: callers compare the result
	// against a nil expectation when no vulnerabilities are returned.
	var vulnerabilities []*model.Vulnerability

	for _, vulnObj := range vulnObjs {
		if vulnObj == nil {
			continue
		}
		group, ok := byType[vulnObj.Type]
		if !ok {
			group = &model.Vulnerability{
				ID:   strings.Join([]string{"vulnerability_types", vulnObj.Type}, ":"),
				Type: vulnObj.Type,
			}
			byType[vulnObj.Type] = group
			vulnerabilities = append(vulnerabilities, group)
		}
		for _, id := range vulnObj.VulnerabilityIDs {
			if id == nil {
				continue
			}
			group.VulnerabilityIDs = append(group.VulnerabilityIDs, &model.VulnerabilityID{
				ID:              id.ID,
				VulnerabilityID: id.VulnerabilityID,
			})
		}
	}
	return vulnerabilities
}

// TestIngestVulnerabilities bulk-ingests vulnerabilities through
// IngestVulnerabilities and checks that the call succeeds and returns exactly
// one result per input.
func TestIngestVulnerabilities(t *testing.T) {
	ctx := context.Background()
	b := setupTest(t)
	tests := []struct {
		// name is the subtest name.
		name string
		// ingests is the batch handed to IngestVulnerabilities in one call.
		ingests []*model.IDorVulnerabilityInput
	}{{
		name:    "Multiple",
		ingests: []*model.IDorVulnerabilityInput{{VulnerabilityInput: testdata.C1}, {VulnerabilityInput: testdata.O1}, {VulnerabilityInput: testdata.G1}},
	}}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			got, err := b.IngestVulnerabilities(ctx, test.ingests)
			if err != nil {
				// Fatalf stops this subtest, so no explicit return is needed.
				t.Fatalf("ingest error: %v", err)
			}
			if len(got) != len(test.ingests) {
				t.Errorf("Unexpected number of results. Wanted: %d, got %d", len(test.ingests), len(got))
			}
		})
	}
}
